import torch

def test_fn():
    """Print the shape of a 3x3 tensor before and after ``unsqueeze``.

    Prints three lines: the original shape, then the shape after inserting
    a length-1 axis at position 0, then at position 1. Returns ``None``.
    """
    matrix = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    print(matrix.shape)

    # Insert a new length-1 axis at each of these positions in turn and
    # report the resulting shape.
    for axis in (0, 1):
        expanded = matrix.unsqueeze(axis)
        print(expanded.shape)

if __name__ == "__main__":
    # Run the shape demo only when this file is executed directly as a
    # script, not when it is imported.
    test_fn()